import pingouin as pg
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FormatStrFormatter

from config import Config, FiguresConfig, DataConfig
from dataProcessing import DataProcessing
from utils import Utils


class PaperVisualizationCreator:
    """Build the paper figures and save them under ``Config.PaperContentOutputPath``.

    Each public static method loads ``DataProcessing.get_trials_answers()``,
    draws one figure, saves it once per entry in ``FiguresConfig.VisFormats``
    and closes it.
    """

    # Parameter shown on the x axis ("fixed ...") of each part's column in
    # ``vis_parameter_over_parts``; the cell whose row is that same parameter
    # is left empty (grid lines only).
    _FIXED_KEY_PER_PART = {1: "size", 2: "tilt", 3: "distance"}

    @staticmethod
    def _padded_limits(values, factor):
        """Return ``(low, high)`` covering ``values[0]..values[-1]``.

        Both ends are widened by ``factor`` times the span between the first
        and last value, so ``values`` must be indexable with at least one item.
        """
        span = values[-1] - values[0]
        return values[0] - span * factor, values[-1] + span * factor

    @staticmethod
    def _index_limits(count, factor):
        """Return ``(low, high)`` for an axis showing the positions ``0..count - 1``.

        Both ends are widened by ``count * factor``.
        """
        return 0 - count * factor, (count - 1) + count * factor

    @staticmethod
    def _every_other_label(values):
        """Return tick labels keeping the values at even positions and blanking the others."""
        return [value if index % 2 == 0 else "" for index, value in enumerate(values)]

    @staticmethod
    def vis_angular_size_and_viewing_angle():
        """Plot kernel density estimates of ``angularSize`` and ``viewingAngle``.

        The figure has one subplot per parameter, side by side, with one KDE
        curve per condition key. Colors come from
        ``FiguresConfig.Palettes["conditionKey"]`` and line styles from
        ``FiguresConfig.DashStyles["content"]``. One shared legend is attached
        to the figure. The result is saved as
        "angularSize and viewingAngle.<format>".
        """
        sns.set_style(FiguresConfig.SnsStyle)
        data = DataProcessing.get_trials_answers()
        condition_palette = FiguresConfig.Palettes["conditionKey"]
        condition_order = list(condition_palette.keys())

        fig_size = (
            FiguresConfig.HalfColumn,
            FiguresConfig.PageHeight * 0.15
        )
        margin = FiguresConfig.calc_margin_single(fig_size=fig_size, absolute_margin=(1.31, 1.05, 0.1, 0.875))
        spacing = FiguresConfig.calc_spacing_complete(fig_size=fig_size, absolute_spacing=0.265)

        fig = plt.figure(
            figsize=fig_size,
            dpi=200
        )
        gs = GridSpec(
            nrows=1,
            ncols=2,
            figure=fig,
            left=margin["left"],
            top=margin["top"],
            right=margin["right"],
            bottom=margin["bottom"],
            wspace=fig_size[0] * spacing["vertical"],
            hspace=fig_size[1] * spacing["horizontal"],
        )

        for column, key in enumerate(["angularSize", "viewingAngle"]):
            ax = fig.add_subplot(gs[0, column])

            # One kdeplot call per condition so each curve can get its own line style.
            for c_key in condition_order:
                sns.kdeplot(
                    ax=ax,
                    data=data[data["conditionKey"] == c_key],
                    x=key,
                    hue="conditionKey",
                    palette=condition_palette,
                    hue_order=condition_order,
                    common_norm=False,
                    linewidth=FiguresConfig.LineWidth,
                    bw_adjust=0.8,
                    # The second "-"-separated field of the condition key picks the dash style.
                    ls=FiguresConfig.DashStyles["content"][c_key.split("-")[1]],
                )

            values = DataConfig.FixedValues[key]
            low, high = PaperVisualizationCreator._padded_limits(values, FiguresConfig.RangeLimitFactor)
            ax.set_xlim(left=low, right=high)
            ax.set_xticks(values)
            ax.set_xlabel(f"{DataConfig.ParameterLabel[key]} ({DataConfig.ParameterAbbreviation[key]}) in {DataConfig.ParameterDimensions[key]}")

            ax.set_ylabel("Kernel Density")

            # Build the shared figure legend once, from the first subplot.
            if column == 0:
                Utils.set_legend_from_vis_with_ls(fig, ax, y_pos=1.02, label_on_fig=True, special_legend=True, ls=["-", "--", "-", "--"])

            if ax.get_legend() is not None:
                ax.get_legend().remove()

        fig.tight_layout()
        for format_str in FiguresConfig.VisFormats:
            fig.savefig(f"{Config.PaperContentOutputPath}/angularSize and viewingAngle.{format_str}")
        plt.close(fig)

    @staticmethod
    def vis_parameter_over_parts():
        """Plot each parameter against the fixed-value index, split by experiment part.

        The grid has one row per parameter in ``key_order`` and one column per
        part. Background axes span each full row (y axis only) and each full
        column (x axis only) and carry the tick labels and axis labels. Each
        cell draws one line per condition key. A cell whose row parameter is
        its column's fixed parameter shows only grid lines. The result is saved
        as "fixed parameters to parameters and parts.<format>".
        """
        sns.set_style(FiguresConfig.SnsStyle)
        data = DataProcessing.get_trials_answers()
        key_order = ["distance", "tilt", "size", "angularSize", "viewingAngle"]
        fixed_key_per_part = PaperVisualizationCreator._FIXED_KEY_PER_PART
        parts = sorted(fixed_key_per_part)
        factor = FiguresConfig.RangeLimitFactor
        condition_palette = FiguresConfig.Palettes["conditionKey"]
        condition_order = list(condition_palette.keys())

        fig_size = (
            FiguresConfig.HalfColumn,
            FiguresConfig.PageHeight * 0.275
        )
        margin = FiguresConfig.calc_margin_single(fig_size=fig_size, absolute_margin=(1.275, 1.0, 0.1, 1.4))
        spacing = FiguresConfig.calc_spacing_single(fig_size=fig_size, absolute_spacing=(0.125, 0.2))  # x4, x2

        fig = plt.figure(
            figsize=fig_size,
            dpi=200
        )
        gs = GridSpec(
            nrows=len(key_order),
            ncols=len(parts),
            figure=fig,
            left=margin["left"],
            top=margin["top"],
            right=margin["right"],
            bottom=margin["bottom"],
            wspace=fig_size[0] * spacing["vertical"],
            hspace=fig_size[1] * spacing["horizontal"],
        )

        # Background axes spanning each full row: only the y axis is visible.
        y_axs = []
        for row, key in enumerate(key_order):
            ax = fig.add_subplot(gs[row, :])
            y_axs.append(ax)
            sns.lineplot(ax=ax)

            ax.get_xaxis().set_visible(False)

            values = DataConfig.FixedValues[key]
            low, high = PaperVisualizationCreator._padded_limits(values, factor)
            ax.set_ylim(bottom=low, top=high)
            ax.set_yticks(values)
            # Label every other tick; the remaining ticks stay unlabeled.
            ax.set_yticklabels(PaperVisualizationCreator._every_other_label(values))
            ax.set_ylabel(f"{DataConfig.ParameterAbbreviation[key]} in {DataConfig.ParameterDimensions[key]}")

            ax.spines['top'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            # Hide this background axes' grid explicitly; the per-cell axes draw the grid lines.
            # (grid(b=...) was removed in Matplotlib 3.7, and None would only toggle the grid.)
            ax.grid(False)

        fig.align_ylabels(y_axs)

        # Background axes spanning each full column: only the x axis is visible.
        x_axs = []
        for part in parts:
            ax = fig.add_subplot(gs[:, part - 1])
            x_axs.append(ax)
            sns.lineplot(ax=ax)

            ax.get_yaxis().set_visible(False)

            x_key = fixed_key_per_part[part]
            x_values = DataConfig.FixedValues[x_key]
            # Data is plotted over "fixedValueIndex", i.e. positions 0..n-1 of the fixed values.
            low, high = PaperVisualizationCreator._index_limits(len(x_values), factor)
            ax.set_xlim(left=low, right=high)
            ax.set_xticks(list(range(len(x_values))))
            ax.set_xticklabels(x_values)
            ax.set_xlabel(f"fixed {DataConfig.ParameterAbbreviation[x_key]} in {DataConfig.ParameterDimensions[x_key]}")

            for label in ax.get_xticklabels():
                label.set_rotation(90)
                label.set_ha('center')

            ax.spines['right'].set_visible(False)
            ax.spines['left'].set_visible(False)
            ax.grid(False)

        fig.align_xlabels(x_axs)

        # Fill the individual cells of the grid.
        legend_ax = None
        for part in parts:
            fixed_key = fixed_key_per_part[part]
            x_count = len(DataConfig.FixedValues[fixed_key])
            x_low, x_high = PaperVisualizationCreator._index_limits(x_count, factor)
            part_data = data[data["part"] == part]

            for row, key in enumerate(key_order):
                ax = fig.add_subplot(gs[row, part - 1])

                if key == fixed_key:
                    # The fixed parameter does not vary within this part: draw an
                    # empty plot so only the grid lines are shown.
                    sns.lineplot(ax=ax)
                else:
                    sns.lineplot(
                        ax=ax,
                        data=part_data,
                        x="fixedValueIndex",
                        y=key,
                        style="conditionKey",
                        hue="conditionKey",
                        hue_order=condition_order,
                        palette=condition_palette,
                        dashes=FiguresConfig.DashStyles["conditionKey"],
                        err_style=None,
                        linewidth=FiguresConfig.LineWidth,
                    )
                    # Remember a cell that actually holds lines to build the figure legend from.
                    legend_ax = ax

                for side in ("top", "right", "bottom", "left"):
                    ax.spines[side].set_visible(False)

                # Ticks are kept (for grid lines) but unlabeled; the background axes carry the labels.
                ax.set_xlim(left=x_low, right=x_high)
                ax.set_xticks(list(range(x_count)))
                ax.set_xticklabels([""] * x_count)
                ax.set_xlabel(None)

                values = DataConfig.FixedValues[key]
                y_low, y_high = PaperVisualizationCreator._padded_limits(values, factor)
                ax.set_ylim(bottom=y_low, top=y_high)
                ax.set_yticks(values)
                ax.set_yticklabels([""] * len(values))
                ax.set_ylabel(None)

                if ax.get_legend() is not None:
                    ax.get_legend().remove()

        Utils.set_legend_from_vis_with_ls(fig, legend_ax, y_pos=1.02, label_on_fig=True, ls=["-", "--", "-", "--"])

        fig.tight_layout()
        for format_str in FiguresConfig.VisFormats:
            fig.savefig(f"{Config.PaperContentOutputPath}/fixed parameters to parameters and parts.{format_str}")
        plt.close(fig)


if __name__ == '__main__':
    # Run the project's rc setup before any figure is created.
    FiguresConfig.set_rc_plot_values()

    # Render the paper figures in a fixed order; each call saves and closes its own figure.
    figure_builders = (
        PaperVisualizationCreator.vis_angular_size_and_viewing_angle,
        PaperVisualizationCreator.vis_parameter_over_parts,
    )
    for build_figure in figure_builders:
        build_figure()
